{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "60KmTK7o6ppd"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "GoAJlrr6LUyj"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a-xjCcnrf10b"
      },
      "source": [
        "# Gemma - self extend Gemma\n",
        "\n",
        "This is one of the accompanying notebooks for the [Large Language Models with Keras](https://www.youtube.com/watch?v=TV7qCk1dBWA) technical session at Google I/O 2024. It demonstrates how to extend the context window of Gemma.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Workshops/[Gemma_1]Self_extend.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2hMRutWLgMe7"
      },
      "source": [
        "# Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use an L4 or A100 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **L4 GPU** or **A100 GPU**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AO7a1Q4Yyc9Z"
      },
      "source": [
        "# Installation\n",
        "\n",
        "Install Keras and KerasNLP with the Gemma model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GS7Kjiw58X9T"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U git+https://github.com/keras-team/keras-nlp.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WEgg_OVIL2HY"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
        "\n",
        "import keras\n",
        "import keras_nlp\n",
        "\n",
        "# for reproducibility\n",
        "keras.utils.set_random_seed(42)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h5H7oE9MtnBA"
      },
      "outputs": [],
      "source": [
        "# Access to Gemma checkpoints on Kaggle\n",
        "\n",
        "# On Colab, set up your Kaggle credentials as Colab secrets and use this:\n",
        "\n",
        "from google.colab import userdata\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wo1xkzr62hXN"
      },
      "source": [
        "# Load the 2B model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dTULEQ29qidE"
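      },
      "source": [
        "Load the instruction-tuned Gemma 1.1 2B preset and compile it with a greedy sampler so that generation is deterministic."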
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bu48vUnbQj0p"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Attaching 'model.safetensors' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'model.safetensors.index.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
            "Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n"
          ]
        }
      ],
      "source": [
        "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma_1.1_instruct_2b_en\")\n",
        "gemma_lm.compile(sampler=keras_nlp.samplers.GreedySampler())  # this is the default"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mS7jstFSlGIO"
      },
      "source": [
        "# Add car dealership data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c3uuWiQ6HmH8"
      },
      "outputs": [],
      "source": [
        "retrieval_car = \"0. 2024 Google IO Car: Full of interesting demos and talks. Runs on AI, pretty good mileage. Ask for a discount at the dealership with this number 1234123.\"\n",
        "car_dealership_data = \"\"\"\n",
        "1. 2019 Chevrolet Silverado 1500 LT, Black: This workhorse is equipped with a powerful V8 engine and a spacious crew cab,offering a blend of power and comfort. Whether you're hauling heavy loads or cruising down the highway, this truck delivers a smooth and confident ride.\n",
        "2. 2021 Toyota Camry SE, Silver: The Toyota Camry is known for its reliability and fuel efficiency, making it a practical choice for daily commuting. This particular SE trim adds a sportier flair with its sleek exterior design and responsive handling.\n",
        "3. 2020 Ford F-150 XLT, Blue: The Ford F-150 is a perennial favorite in the truck market, and this XLT model features a comfortable interior and a variety of useful tech features. With its powerful engine and capable towing capacity, it's ready for work or play.\n",
        "4. 2022 Honda Civic LX, White: The Honda Civic is a popular choice for those seeking a reliable and efficient sedan. This LX model is equipped with essential features and offers a comfortable and practical driving experience.\n",
        "5. 2023 Nissan Altima SR, Red: The Nissan Altima SR boasts a sporty appearance and a peppy engine, making it a fun-to-drive sedan. With its modern interior and available technology features, it offers both style and substance.\n",
        "6. 2018 Jeep Grand Cherokee Laredo, Green: The Jeep Grand Cherokee is known for its ruggedness and off-road capability, making it a popular choice for adventure seekers. This Laredo model offers a comfortable ride and plenty of cargo space for all your gear.\n",
        "7. 2020 Ram 1500 Big Horn, Gray: The Ram 1500 Big Horn is a well-equipped truck with a powerful engine and a spacious cabin. It's a versatile workhorse that's also comfortable for everyday driving.\n",
        "8. 2019 Subaru Outback Limited, Brown: The Subaru Outback is a popular choice for those seeking a vehicle with all-wheel drive and plenty of cargo space. This Limited model features a luxurious interior and a variety of advanced safety features.\n",
        "9. 2021 Hyundai Tucson SEL, Silver: The Hyundai Tucson is a stylish and practical compact SUV with a comfortable interior and a smooth ride. This SEL model is equipped with a variety of modern features to enhance your driving experience.\n",
        "10. 2023 Kia Sportage EX, Blue: The Kia Sportage is a popular SUV with a stylish exterior and a comfortable interior. This EX model is equipped with a range of features that make it a great choice for families or adventurers.\n",
        "11. 2017 Ford Escape SE, White: The Ford Escape is a versatile compact SUV with a comfortable ride and a spacious interior. This SE model is equipped with a range of features that make it a great choice for families or adventurers.\n",
        "12. 2022 Chevrolet Equinox LT, Red: The Chevrolet Equinox is a popular SUV with a stylish exterior and a comfortable interior. This LT model is equipped with a range of features that make it a great choice for families or adventurers.\n",
        "13. 2020 Toyota RAV4 XLE, Black: The Toyota RAV4 is a reliable and fuel-efficient SUV that's popular for its practicality and versatility. This XLE model is equipped with a variety of features to enhance your driving experience.\n",
        "14. 2018 Honda CR-V EX-L, Gray: The Honda CR-V is a perennial favorite in the SUV market, known for its reliability and comfortable ride. This EX-L model offers a luxurious interior and a range of advanced features.\n",
        "15. 2021 Mazda CX-5 Touring, Brown: The Mazda CX-5 is a stylish and sporty SUV with a responsive handling and a comfortable ride. This Touring model is equipped with a variety of modern features to enhance your driving experience.\n",
        "16. 2023 Nissan Rogue SL, Silver: The Nissan Rogue SL is a top-of-the-line SUV with a luxurious interior and a range of advanced features. It's a great choice for those seeking a spacious and comfortable vehicle with all the bells and whistles.\n",
        "17. 2019 Dodge Durango RTT, Blue: The Dodge Durango RTT is a powerful and spacious SUV with a range of performance-oriented features. It's a great choice for those seeking a vehicle that can handle both everyday driving and adventurous off-road excursions.\n",
        "18. 2020 GMC Acadia SLE, White: The GMC Acadia is a spacious and versatile SUV. This SLE model is equipped with a variety of modern features to enhance your driving experience.\n",
        "19. 2018 Jeep Wrangler Sahara, Green: The Jeep Wrangler is an iconic off-road vehicle with a removable top and a rugged design. This Sahara model is equipped with a range of features that make it comfortable for everyday driving as well as off-road adventures.\n",
        "20. 2021 Ford Bronco Sport Badlands, Orange: The Ford Bronco Sport Badlands is a rugged and capable SUV designed for off-road adventures. It's equipped with a range of features that make it ready to tackle any terrain.\n",
        "21. 2022 Toyota 4Runner TRD Off-Road, Gray: The Toyota 4Runner is a legendary off-road vehicle with a rugged design and a powerful engine. This TRD Off-Road model is equipped with a range of features that make it ready for any adventure.\n",
        "22. 2020 Chevrolet Tahoe LT, Black: The Chevrolet Tahoe is a spacious and versatile SUV with a comfortable interior and a powerful engine. This LT model is equipped with a range of features that make it a great choice for families or adventurers.\n",
        "23. 2019 GMC Yukon Denali, White: The GMC Yukon Denali is a luxurious and powerful SUV with a spacious interior and a range of premium features. It's a great choice for those seeking a comfortable and stylish vehicle with plenty of power.\n",
        "24. 2021 Cadillac Escalade Premium Luxury, Blue: The Cadillac Escalade is a flagship SUV with a luxurious interior and a powerful engine. This Premium Luxury model is equipped with a range of advanced features that make it a true statement of luxury and style.\n",
        "25. 2022 Lincoln Navigator Reserve, Black: The Lincoln Navigator is a luxurious and spacious SUV. This Reserve model is equipped with a range of premium features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
        "26. 2023 Lexus LX 600 Ultra Luxury, Brown: The Lexus LX 600 Ultra Luxury is a top-of-the-line SUV with a luxurious interior and a range of advanced features. It's a great choice for those seeking a vehicle that combines opulence, technology, and off-road capability.\n",
        "27. 2020 BMW X5 xDrive40i, Silver: The BMW X5 is a performance-oriented SUV with a powerful engine and agile handling. This xDrive40i model is equipped with a range of features that make it a great choice for those seeking a luxurious and sporty driving experience.\n",
        "28. 2019 Mercedes-Benz GLE 450 4MATIC, White: The Mercedes-Benz GLE is a stylish and luxurious SUV with a comfortable interior and a powerful engine. This 450 4MATIC model is equipped with a range of advanced features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
        "29. 2021 Audi Q7 Premium Plus, Blue: The Audi Q7 is a spacious and luxurious SUV with a comfortable interior and a powerful engine. This Premium Plus model is equipped with a range of advanced features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
        "30. 2023 Porsche Cayenne E-Hybrid, Red: The Porsche Cayenne is a performance-oriented SUV with a powerful engine and agile handling. This E-Hybrid model is equipped with a range of features that make it a great choice for those seeking a luxurious and sporty driving experience.\n",
        "31. 2022 Land Rover Range Rover Velar R-Dynamic SE, Black: The Land Rover Range Rover Velar is a stylish and luxurious SUV with a comfortable interior and a powerful engine. This R-Dynamic SE model is equipped with a range of advanced features that make it a great choice for those seeking a refined and comfortable driving experience.\n",
        "32. 2020 Tesla Model Y Long Range, White: The Tesla Model Y is a popular electric SUV with a spacious interior, impressive range, and advanced technology features. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
        "33. 2018 Nissan Leaf SL Plus, Gray: The Nissan Leaf is a popular electric car with a comfortable ride, spacious interior, and impressive fuel efficiency. This SL Plus model is equipped with a range of features that make it a great choice for those seeking a practical and affordable electric vehicle.\n",
        "34. 2021 Chevrolet Bolt EV Premier, Blue: The Chevrolet Bolt EV is a popular electric car with a fun driving experience, spacious interior, and impressive range. This Premier model is equipped with a range of features that make it a great choice for those seeking a practical and affordable electric vehicle.\n",
        "35. 2023 Hyundai Kona Electric Limited, Red: The Hyundai Kona Electric is a stylish and practical electric SUV. This Limited model is equipped with a range of features that make it a great choice for those seeking a sustainable and stylish vehicle.\n",
        "36. 2022 Kia Niro EV EX Premium, Black: The Kia Niro EV is a popular electric car with a comfortable ride, spacious interior, and impressive fuel efficiency. This EX Premium model is equipped with a range of features that make it a great choice for those seeking a practical and affordable electric vehicle.\n",
        "37. 2020 Volkswagen ID.4 Pro, Silver: The Volkswagen ID.4 is a popular electric SUV with a spacious interior, modern features, and impressive range. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
        "38. 2019 Tesla Model 3 Long Range, White: The Tesla Model 3 is a popular electric sedan with a sleek design, impressive range, and advanced technology features. It's a great choice for those seeking a sustainable and innovative vehicle with a sporty driving experience.\n",
        "39. 2021 Ford Mustang Mach-E Premium, Red: The Ford Mustang Mach-E is a popular electric SUV with a powerful performance, stylish design, and modern features. It's a great choice for those seeking a sustainable and sporty vehicle with plenty of space and comfort.\n",
        "40. 2023 Hyundai IONIQ 5 SE, Blue: The Hyundai IONIQ 5 is a popular electric car with a spacious interior, modern features, and impressive range. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
        "41. 2022 Kia EV6 GT-Line, Black: The Kia EV6 is a popular electric SUV with a stylish exterior, modern interior, and impressive range. It's a great choice for those seeking a sustainable and innovative vehicle with plenty of space and comfort.\n",
        "42. 2020 Nissan Versa S, White: The Nissan Versa is a budget-friendly sedan that offers a surprising amount of interior space and fuel efficiency. This S model is a great choice for those seeking a practical and affordable vehicle for everyday commuting.\n",
        "43. 2019 Mitsubishi Mirage ES, Red: The Mitsubishi Mirage is a compact car known for its fuel efficiency and affordability. This ES model offers a few extra features and a bit more style to make your daily commute more enjoyable.\n",
        "44. 2021 Chevrolet Spark LS, Black: The Chevrolet Spark is a small and nimble city car that's easy to park and maneuver in tight spaces. This LS model is a great choice for those seeking an affordable and practical vehicle for city driving.\n",
        "45. 2022 Kia Rio LX, Gray: The Kia Rio is a stylish and affordable subcompact car with a surprisingly spacious interior. This LX model is a great choice for those seeking a practical and stylish vehicle for everyday commuting.\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a5eYP2UBHsH9"
      },
      "outputs": [],
      "source": [
        "__START_TURN_USER__ = \"<start_of_turn>user\\n\"\n",
        "__START_TURN_MODEL__ = \"<start_of_turn>model\\n\"\n",
        "__END_TURN__ = \"<end_of_turn>\\n\""
      ]
    },
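    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The turn markers above follow the chat format used by the instruction-tuned Gemma checkpoints. As a minimal sketch, a small helper can assemble a single-turn prompt from them. `build_prompt` and its arguments are hypothetical names introduced here for illustration only; the rest of this notebook concatenates the markers by hand."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sketch: markers are redefined so this cell is self-contained.\n",
        "START_TURN_USER = \"<start_of_turn>user\\n\"\n",
        "START_TURN_MODEL = \"<start_of_turn>model\\n\"\n",
        "END_TURN = \"<end_of_turn>\\n\"\n",
        "\n",
        "\n",
        "def build_prompt(user_text, answer_prefix=\"\"):\n",
        "    \"\"\"Wrap user_text in a user turn, then open the model turn (hypothetical helper).\"\"\"\n",
        "    return START_TURN_USER + user_text + END_TURN + START_TURN_MODEL + answer_prefix\n",
        "\n",
        "\n",
        "# The optional answer_prefix pre-fills the beginning of the model's reply.\n",
        "print(build_prompt(\"Hi there\", answer_prefix=\"Hello\"))"
      ]
    },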
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yOTBqH3nA0xX"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Prompt has 2716 tokens\n",
            "====================\n",
            "Gemma output:  The discount code is 1234123.\n",
            "====================\n",
            "Actual discount code is 1234123\n",
            "\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Ask for the discount code and pre-fill the start of the model's answer.\n",
        "prompt_postfix = (\n",
        "    \"\\nWhat is the discount code for the 2024 Google IO Car?\"\n",
        "    + __END_TURN__\n",
        "    + __START_TURN_MODEL__\n",
        "    + \" The discount code is \"\n",
        ")\n",
        "# Baseline: the target entry followed by one copy of the dealership data.\n",
        "prompt = __START_TURN_USER__ + retrieval_car + car_dealership_data * 1 + prompt_postfix\n",
        "tokenized = gemma_lm.preprocessor.tokenizer.tokenize(prompt)\n",
        "print(f\"Prompt has {len(tokenized)} tokens\")\n",
        "# Cap the total length at the prompt plus 20 tokens.\n",
        "gemma_page = gemma_lm.generate(prompt, max_length=len(tokenized) + 20)\n",
        "# Keep only the model's turn.\n",
        "gemma_page = gemma_page.split(__START_TURN_MODEL__)[1]\n",
        "print(\"=\" * 20)\n",
        "print(\"Gemma output:\", gemma_page)\n",
        "print(\"=\" * 20)\n",
        "print(\"Actual discount code is 1234123\")\n",
        "print(\"\\n\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EQXvY2UDLsyj"
      },
      "source": [
        "# Create self-extend attention call"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "znfYstRD9m6U"
      },
      "outputs": [],
      "source": [
        "from keras import ops\n",
        "\n",
        "# Number of neighboring positions that share one grouped position.\n",
        "GROUP_SIZE = 8\n",
        "# Relative distance below which normal (ungrouped) attention is used.\n",
        "WINDOW_SIZE = 1024\n",
        "\n",
        "\n",
        "def extend_compute_attention(\n",
        "    self,\n",
        "    query,\n",
        "    key,\n",
        "    value,\n",
        "    attention_mask,\n",
        "    training=False,\n",
        "):\n",
        "    # Shapes for attention.\n",
        "    head_dim, kv_heads = self.head_dim, self.num_key_value_heads\n",
        "    batch_size, query_len = ops.shape(query)[:2]\n",
        "    key_len = ops.shape(key)[1]\n",
        "\n",
        "    def apply_rope(inputs, positions):\n",
        "        outputs = self.rope_layer(inputs, positions=positions)\n",
        "        # Gemma has a different shape for RoPE outputs.\n",
        "        outputs = ops.stack(ops.split(outputs, 2, axis=-1), axis=-1)\n",
        "        return ops.reshape(outputs, ops.shape(inputs))\n",
        "\n",
        "    def attn(query, key):\n",
        "        query *= ops.cast(1 / ops.sqrt(head_dim), dtype=query.dtype)\n",
        "        new_shape = (batch_size, query_len, kv_heads, -1, head_dim)\n",
        "        query = ops.reshape(query, new_shape)\n",
        "        return ops.einsum(\"btkgh,bskh->bkgts\", query, key)\n",
        "\n",
        "    # Calculate normal positions.\n",
        "    cache_index = ops.sum(attention_mask[0, 0, :]) - 1\n",
        "    query_positions = ops.arange(query_len) + cache_index\n",
        "    key_positions = ops.arange(key_len)\n",
        "\n",
        "    # Normal RoPE and compute logits.\n",
        "    local_query = apply_rope(query, query_positions)\n",
        "    local_key = apply_rope(key, key_positions)\n",
        "    local_logits = attn(local_query, local_key)\n",
        "\n",
        "    # Calculate grouped positions.\n",
        "    query_positions = query_positions // GROUP_SIZE\n",
        "    # Shift grouped queries so relative positions stay continuous at the\n",
        "    # window boundary (WINDOW_SIZE - WINDOW_SIZE // GROUP_SIZE in Self-Extend).\n",
        "    query_positions += WINDOW_SIZE - WINDOW_SIZE // GROUP_SIZE\n",
        "    key_positions = key_positions // GROUP_SIZE\n",
        "\n",
        "    # Grouped RoPE and compute logits.\n",
        "    grouped_query = apply_rope(query, query_positions)\n",
        "    grouped_key = apply_rope(key, key_positions)\n",
        "    grouped_logits = attn(grouped_query, grouped_key)\n",
        "\n",
        "    # Combine attention logits.\n",
        "    attn_mask = attention_mask[:, None, None, :, :]\n",
        "    # Flip the attn mask, cumsum, and flip back to get relative positions.\n",
        "    group_mask = ops.flip(ops.cumsum(ops.flip(attn_mask, -1), -1), -1)\n",
        "    # Take local attention where relative positions are less than window size.\n",
        "    local_mask = group_mask < WINDOW_SIZE\n",
        "    logits = ops.where(local_mask, local_logits, grouped_logits)\n",
        "\n",
        "    # Softmax and weighted value sum.\n",
        "    scores = ops.cast(self.softmax(logits, attn_mask), self.compute_dtype)\n",
        "    results = ops.einsum(\"bkgts,bskh->btkgh\", scores, value)\n",
        "    return ops.reshape(results, (batch_size, query_len, -1, head_dim))"
      ]
    },
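    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the position arithmetic easier to follow, here is a small NumPy sketch of the relative position that RoPE effectively sees for one query/key pair. It is a simplified illustration, not the code path used by the model: `relative_position` is a hypothetical helper, the shift `WINDOW_SIZE - WINDOW_SIZE // GROUP_SIZE` is assumed from the Self-Extend formulation, and the exact window boundary may differ by one from the mask-based check in `extend_compute_attention`.\n",
        "\n",
        "With a group size of 8 and a window of 1024, a query at position 10,000 attending to the first token sees a relative position of 2146 rather than 10,000, which is well below the 8,192-token context that Gemma was trained with."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def relative_position(q_pos, k_pos, group_size=8, window_size=1024):\n",
        "    \"\"\"Relative position used for a (query, key) pair (hypothetical helper).\"\"\"\n",
        "    q_pos = np.asarray(q_pos)\n",
        "    k_pos = np.asarray(k_pos)\n",
        "    # Nearby tokens keep their exact relative position.\n",
        "    local = q_pos - k_pos\n",
        "    # Distant tokens use floor-divided positions, with the query shifted so\n",
        "    # the two regimes meet at the window boundary (assumed shift).\n",
        "    shift = window_size - window_size // group_size\n",
        "    grouped = (q_pos // group_size + shift) - k_pos // group_size\n",
        "    return np.where(local < window_size, local, grouped)\n",
        "\n",
        "\n",
        "print(int(relative_position(100, 90)))  # 10: inside the window, unchanged\n",
        "print(int(relative_position(1024, 0)))  # 1024: grouped value at the boundary\n",
        "print(int(relative_position(10_000, 0)))  # 2146: far token, compressed"
      ]
    },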
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2jnQoTupaFoF"
      },
      "source": [
        "# Patching code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1RtCV0dhaIKZ"
      },
      "outputs": [],
      "source": [
        "from types import MethodType\n",
        "\n",
        "for layer in gemma_lm.backbone.transformer_layers:\n",
        "    # Patch the model's RoPE calls with an identity function.\n",
        "    # RoPE is applied inside the patched attention computation instead.\n",
        "    layer.attention._apply_rope = MethodType(lambda s, x, i: x, layer.attention)\n",
        "    # Patch the model's attention with the Self-Extend version.\n",
        "    layer.attention._compute_attention = MethodType(\n",
        "        extend_compute_attention, layer.attention\n",
        "    )"
      ]
    },
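    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The patching above relies on `types.MethodType` to bind a plain function to one specific object. The toy example below shows that only the patched instance changes behavior, which is why the loop patches every transformer layer individually. The `Greeter` class and `shout` function are made up for illustration and have nothing to do with KerasNLP."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from types import MethodType\n",
        "\n",
        "\n",
        "class Greeter:\n",
        "    def greet(self, name):\n",
        "        return f\"Hello, {name}\"\n",
        "\n",
        "\n",
        "def shout(self, name):\n",
        "    # Replacement that receives the instance as `self`, like a normal method.\n",
        "    return f\"HELLO, {name.upper()}!\"\n",
        "\n",
        "\n",
        "patched, untouched = Greeter(), Greeter()\n",
        "# Bind the function to a single instance; the class itself is not modified.\n",
        "patched.greet = MethodType(shout, patched)\n",
        "\n",
        "print(patched.greet(\"gemma\"))  # HELLO, GEMMA!\n",
        "print(untouched.greet(\"gemma\"))  # Hello, gemma"
      ]
    },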
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A9mh1itbq3uF"
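      },
      "outputs": [],
      "source": [
        "# Reset the built state and the cached generate function so that the patched\n",
        "# attention is picked up when the model is recompiled.\n",
        "gemma_lm.built = False\n",
        "gemma_lm.generate_function = None\n",
        "# Show full tracebacks, which helps when debugging the patched layers.\n",
        "keras.config.disable_traceback_filtering()\n",
        "gemma_lm.compile(sampler=keras_nlp.samplers.GreedySampler())"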
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KlIsOfcykYwN"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "\"<start_of_turn>user\\nAnswer me: Hi there<end_of_turn>\\n<start_of_turn>model\\nHi there! 👋 It's great to hear from you. What would you like to talk about today?\""
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "gemma_lm.generate(\n",
        "    __START_TURN_USER__ + \"Answer me: Hi there\" + __END_TURN__ + __START_TURN_MODEL__,\n",
        "    max_length=80,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xaRDrieuUA6a"
      },
      "source": [
        "# Test with SelfExtend Algorithm - 10k context window"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9O15mj27Soaq"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Prompt has 10639 tokens\n",
            "====================\n",
            "Gemma output:  The discount code is 1234123\n",
            "====================\n",
            "Actual discount code is 1234123\n",
            "\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Ask for the discount code and pre-fill the start of the model's answer.\n",
        "prompt_postfix = (\n",
        "    \"\\nWhat is the discount code for the 2024 Google IO Car?\"\n",
        "    + __END_TURN__\n",
        "    + __START_TURN_MODEL__\n",
        "    + \" The discount code is \"\n",
        ")\n",
        "# Long context: the target entry followed by four copies of the dealership data.\n",
        "prompt = __START_TURN_USER__ + retrieval_car + car_dealership_data * 4 + prompt_postfix\n",
        "tokenized = gemma_lm.preprocessor.tokenizer.tokenize(prompt)\n",
        "print(f\"Prompt has {len(tokenized)} tokens\")\n",
        "# Cap the total length at the prompt plus 20 tokens.\n",
        "gemma_page = gemma_lm.generate(prompt, max_length=len(tokenized) + 20)\n",
        "# Keep only the first sentence of the model's turn.\n",
        "gemma_page = gemma_page.split(__START_TURN_MODEL__)[1].split(\".\")[0]\n",
        "print(\"=\" * 20)\n",
        "print(\"Gemma output:\", gemma_page)\n",
        "print(\"=\" * 20)\n",
        "print(\"Actual discount code is 1234123\")\n",
        "print(\"\\n\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzKsCGIN0yX5"
      },
      "source": [
        "# What's next\n",
        "\n",
        "In this tutorial, you learned how to extend the context window of Gemma using Keras on JAX.\n",
        "\n",
        "Here are a few suggestions for what else to learn about Keras and JAX:\n",
        "* [Distributed training with Keras 3](https://keras.io/guides/distribution/).\n",
        "* [Writing a custom training loop for a Keras model in JAX](https://keras.io/guides/writing_a_custom_training_loop_in_jax/).\n",
        "\n",
        "And a couple of more basic Gemma tutorials:\n",
        "\n",
        "* [Get started with Keras Gemma](https://ai.google.dev/gemma/docs/get_started).\n",
        "* [Fine-tune the Gemma model on GPU](https://ai.google.dev/gemma/docs/lora_tuning)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_1]Self_extend.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
